{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Sample Submission.csv\n",
      "/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Test.csv\n",
      "/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Train.csv\n"
     ]
    }
   ],
   "source": [
    "# This Python 3 environment comes with many helpful analytics libraries installed\n",
    "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
    "# For example, here's several helpful packages to load\n",
    "\n",
    "import numpy as np # linear algebra\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "\n",
    "# Input data files are available in the read-only \"../input/\" directory\n",
    "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
    "\n",
    "import os\n",
    "for dirname, _, filenames in os.walk('/kaggle/input'):\n",
    "    for filename in filenames:\n",
    "        print(os.path.join(dirname, filename))\n",
    "\n",
    "# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
    "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
   },
   "outputs": [],
   "source": [
    "train = pd.read_csv(\"/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Train.csv\")\n",
    "test = pd.read_csv(\"/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Test.csv\")\n",
    "samp = pd.read_csv(\"/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Sample Submission.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text_ID</th>\n",
       "      <th>Product_Description</th>\n",
       "      <th>Product_Type</th>\n",
       "      <th>Sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3057</td>\n",
       "      <td>The Web DesignerÛªs Guide to iOS (and Android...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6254</td>\n",
       "      <td>RT @mention Line for iPad 2 is longer today th...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>8212</td>\n",
       "      <td>Crazy that Apple is opening a temporary store ...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4422</td>\n",
       "      <td>The lesson from Google One Pass: In this digit...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5526</td>\n",
       "      <td>RT @mention At the panel: &amp;quot;Your mom has a...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Text_ID                                Product_Description  Product_Type  \\\n",
       "0     3057  The Web DesignerÛªs Guide to iOS (and Android...             9   \n",
       "1     6254  RT @mention Line for iPad 2 is longer today th...             9   \n",
       "2     8212  Crazy that Apple is opening a temporary store ...             9   \n",
       "3     4422  The lesson from Google One Pass: In this digit...             9   \n",
       "4     5526  RT @mention At the panel: &quot;Your mom has a...             9   \n",
       "\n",
       "   Sentiment  \n",
       "0          2  \n",
       "1          2  \n",
       "2          2  \n",
       "3          2  \n",
       "4          2  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9    4070\n",
       "6     665\n",
       "2     465\n",
       "7     327\n",
       "3     300\n",
       "5     213\n",
       "8     194\n",
       "1      59\n",
       "0      52\n",
       "4      19\n",
       "Name: Product_Type, dtype: int64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train['Product_Type'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2    3765\n",
       "3    2089\n",
       "1     399\n",
       "0     111\n",
       "Name: Sentiment, dtype: int64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train['Sentiment'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Product_Type  Sentiment\n",
       "0             3              47\n",
       "              1               4\n",
       "              2               1\n",
       "1             3              53\n",
       "              1               5\n",
       "              2               1\n",
       "2             3             379\n",
       "              1              69\n",
       "              2              15\n",
       "              0               2\n",
       "3             3             240\n",
       "              1              49\n",
       "              2              10\n",
       "              0               1\n",
       "4             3              19\n",
       "5             3             171\n",
       "              1              36\n",
       "              2               6\n",
       "6             3             563\n",
       "              1              84\n",
       "              2              16\n",
       "              0               2\n",
       "7             3             276\n",
       "              1              43\n",
       "              2               8\n",
       "8             3             122\n",
       "              1              65\n",
       "              2               6\n",
       "              0               1\n",
       "9             2            3702\n",
       "              3             219\n",
       "              0             105\n",
       "              1              44\n",
       "Name: Sentiment, dtype: int64"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.groupby(['Product_Type'])['Sentiment'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sentiment  Product_Type\n",
       "0          9                105\n",
       "           2                  2\n",
       "           6                  2\n",
       "           3                  1\n",
       "           8                  1\n",
       "1          6                 84\n",
       "           2                 69\n",
       "           8                 65\n",
       "           3                 49\n",
       "           9                 44\n",
       "           7                 43\n",
       "           5                 36\n",
       "           1                  5\n",
       "           0                  4\n",
       "2          9               3702\n",
       "           6                 16\n",
       "           2                 15\n",
       "           3                 10\n",
       "           7                  8\n",
       "           5                  6\n",
       "           8                  6\n",
       "           0                  1\n",
       "           1                  1\n",
       "3          6                563\n",
       "           2                379\n",
       "           7                276\n",
       "           3                240\n",
       "           9                219\n",
       "           5                171\n",
       "           8                122\n",
       "           1                 53\n",
       "           0                 47\n",
       "           4                 19\n",
       "Name: Product_Type, dtype: int64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.groupby(['Sentiment'])['Product_Type'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9    1731\n",
       "6     281\n",
       "2     196\n",
       "7     143\n",
       "3     130\n",
       "8     103\n",
       "5      80\n",
       "0      26\n",
       "1      22\n",
       "4      16\n",
       "Name: Product_Type, dtype: int64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test['Product_Type'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting simpletransformers\n",
      "  Downloading simpletransformers-0.47.7-py3-none-any.whl (208 kB)\n",
      "\u001b[K     |████████████████████████████████| 208 kB 2.6 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting seqeval\n",
      "  Downloading seqeval-0.0.12.tar.gz (21 kB)\n",
      "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (1.1.1)\n",
      "Requirement already satisfied: scipy in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (1.4.1)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (2.23.0)\n",
      "Collecting tqdm>=4.47.0\n",
      "  Downloading tqdm-4.48.2-py2.py3-none-any.whl (68 kB)\n",
      "\u001b[K     |████████████████████████████████| 68 kB 3.8 MB/s  eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: scikit-learn in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (0.23.2)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (1.18.5)\n",
      "Requirement already satisfied: tensorboardx in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (2.1)\n",
      "Collecting streamlit\n",
      "  Downloading streamlit-0.66.0-py2.py3-none-any.whl (7.2 MB)\n",
      "\u001b[K     |████████████████████████████████| 7.2 MB 10.1 MB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: wandb in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (0.9.5)\n",
      "Requirement already satisfied: tokenizers in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (0.7.0)\n",
      "Collecting transformers>=3.0.2\n",
      "  Downloading transformers-3.1.0-py3-none-any.whl (884 kB)\n",
      "\u001b[K     |████████████████████████████████| 884 kB 54.6 MB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: regex in /opt/conda/lib/python3.7/site-packages (from simpletransformers) (2020.4.4)\n",
      "Requirement already satisfied: Keras>=2.2.4 in /opt/conda/lib/python3.7/site-packages (from seqeval->simpletransformers) (2.4.3)\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->simpletransformers) (2.8.1)\n",
      "Requirement already satisfied: pytz>=2017.2 in /opt/conda/lib/python3.7/site-packages (from pandas->simpletransformers) (2019.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->simpletransformers) (2020.6.20)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->simpletransformers) (2.9)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/lib/python3.7/site-packages (from requests->simpletransformers) (3.0.4)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->simpletransformers) (1.24.3)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->simpletransformers) (2.1.0)\n",
      "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->simpletransformers) (0.14.1)\n",
      "Requirement already satisfied: protobuf>=3.8.0 in /opt/conda/lib/python3.7/site-packages (from tensorboardx->simpletransformers) (3.13.0)\n",
      "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from tensorboardx->simpletransformers) (1.14.0)\n",
      "Requirement already satisfied: altair>=3.2.0 in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (4.1.0)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (7.2.0)\n",
      "Requirement already satisfied: click>=7.0 in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (7.1.1)\n",
      "Requirement already satisfied: blinker in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (1.4)\n",
      "Collecting base58\n",
      "  Downloading base58-2.0.1-py3-none-any.whl (4.3 kB)\n",
      "Collecting pydeck>=0.1.dev5\n",
      "  Downloading pydeck-0.5.0b1-py2.py3-none-any.whl (4.4 MB)\n",
      "\u001b[K     |████████████████████████████████| 4.4 MB 48.1 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting cachetools>=4.0\n",
      "  Downloading cachetools-4.1.1-py3-none-any.whl (10 kB)\n",
      "Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (20.1)\n",
      "Requirement already satisfied: tzlocal in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (2.1)\n",
      "Requirement already satisfied: toml in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (0.10.0)\n",
      "Requirement already satisfied: tornado>=5.0 in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (5.0.2)\n",
      "Collecting enum-compat\n",
      "  Downloading enum_compat-0.0.3-py3-none-any.whl (1.3 kB)\n",
      "Requirement already satisfied: watchdog in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (0.10.2)\n",
      "Requirement already satisfied: pyarrow in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (0.16.0)\n",
      "Requirement already satisfied: botocore>=1.13.44 in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (1.17.48)\n",
      "Collecting astor\n",
      "  Downloading astor-0.8.1-py2.py3-none-any.whl (27 kB)\n",
      "Collecting validators\n",
      "  Downloading validators-0.18.1-py3-none-any.whl (19 kB)\n",
      "Requirement already satisfied: boto3 in /opt/conda/lib/python3.7/site-packages (from streamlit->simpletransformers) (1.14.48)\n",
      "Requirement already satisfied: configparser>=3.8.1 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (5.0.0)\n",
      "Requirement already satisfied: sentry-sdk>=0.4.0 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (0.17.0)\n",
      "Requirement already satisfied: shortuuid>=0.5.0 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (1.0.1)\n",
      "Requirement already satisfied: PyYAML>=3.10 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (5.3.1)\n",
      "Requirement already satisfied: psutil>=5.0.0 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (5.7.0)\n",
      "Requirement already satisfied: docker-pycreds>=0.4.0 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (0.4.0)\n",
      "Requirement already satisfied: subprocess32>=3.5.3 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (3.5.4)\n",
      "Requirement already satisfied: GitPython>=1.0.0 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (3.1.1)\n",
      "Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (7.352.0)\n",
      "Requirement already satisfied: gql==0.2.0 in /opt/conda/lib/python3.7/site-packages (from wandb->simpletransformers) (0.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from transformers>=3.0.2->simpletransformers) (3.0.10)\n",
      "Requirement already satisfied: sacremoses in /opt/conda/lib/python3.7/site-packages (from transformers>=3.0.2->simpletransformers) (0.0.43)\n",
      "Requirement already satisfied: sentencepiece!=0.1.92 in /opt/conda/lib/python3.7/site-packages (from transformers>=3.0.2->simpletransformers) (0.1.91)\n",
      "Requirement already satisfied: h5py in /opt/conda/lib/python3.7/site-packages (from Keras>=2.2.4->seqeval->simpletransformers) (2.10.0)\n",
      "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from protobuf>=3.8.0->tensorboardx->simpletransformers) (46.1.3.post20200325)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.7/site-packages (from altair>=3.2.0->streamlit->simpletransformers) (2.11.2)\n",
      "Requirement already satisfied: jsonschema in /opt/conda/lib/python3.7/site-packages (from altair>=3.2.0->streamlit->simpletransformers) (3.2.0)\n",
      "Requirement already satisfied: toolz in /opt/conda/lib/python3.7/site-packages (from altair>=3.2.0->streamlit->simpletransformers) (0.10.0)\n",
      "Requirement already satisfied: entrypoints in /opt/conda/lib/python3.7/site-packages (from altair>=3.2.0->streamlit->simpletransformers) (0.3)\n",
      "Requirement already satisfied: traitlets>=4.3.2 in /opt/conda/lib/python3.7/site-packages (from pydeck>=0.1.dev5->streamlit->simpletransformers) (4.3.3)\n",
      "Requirement already satisfied: ipywidgets>=7.0.0 in /opt/conda/lib/python3.7/site-packages (from pydeck>=0.1.dev5->streamlit->simpletransformers) (7.5.1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting ipykernel>=5.1.2; python_version >= \"3.4\"\n",
      "  Downloading ipykernel-5.3.4-py3-none-any.whl (120 kB)\n",
      "\u001b[K     |████████████████████████████████| 120 kB 61.3 MB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->streamlit->simpletransformers) (2.4.7)\n",
      "Requirement already satisfied: pathtools>=0.1.1 in /opt/conda/lib/python3.7/site-packages (from watchdog->streamlit->simpletransformers) (0.1.2)\n",
      "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /opt/conda/lib/python3.7/site-packages (from botocore>=1.13.44->streamlit->simpletransformers) (0.10.0)\n",
      "Requirement already satisfied: docutils<0.16,>=0.10 in /opt/conda/lib/python3.7/site-packages (from botocore>=1.13.44->streamlit->simpletransformers) (0.15.2)\n",
      "Requirement already satisfied: decorator>=3.4.0 in /opt/conda/lib/python3.7/site-packages (from validators->streamlit->simpletransformers) (4.4.2)\n",
      "Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /opt/conda/lib/python3.7/site-packages (from boto3->streamlit->simpletransformers) (0.3.3)\n",
      "Requirement already satisfied: gitdb<5,>=4.0.1 in /opt/conda/lib/python3.7/site-packages (from GitPython>=1.0.0->wandb->simpletransformers) (4.0.4)\n",
      "Requirement already satisfied: promise<3,>=2.0 in /opt/conda/lib/python3.7/site-packages (from gql==0.2.0->wandb->simpletransformers) (2.3)\n",
      "Requirement already satisfied: graphql-core<2,>=0.5.0 in /opt/conda/lib/python3.7/site-packages (from gql==0.2.0->wandb->simpletransformers) (1.1)\n",
      "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/lib/python3.7/site-packages (from jinja2->altair>=3.2.0->streamlit->simpletransformers) (1.1.1)\n",
      "Requirement already satisfied: attrs>=17.4.0 in /opt/conda/lib/python3.7/site-packages (from jsonschema->altair>=3.2.0->streamlit->simpletransformers) (19.3.0)\n",
      "Requirement already satisfied: pyrsistent>=0.14.0 in /opt/conda/lib/python3.7/site-packages (from jsonschema->altair>=3.2.0->streamlit->simpletransformers) (0.16.0)\n",
      "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /opt/conda/lib/python3.7/site-packages (from jsonschema->altair>=3.2.0->streamlit->simpletransformers) (1.6.0)\n",
      "Requirement already satisfied: ipython-genutils in /opt/conda/lib/python3.7/site-packages (from traitlets>=4.3.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.2.0)\n",
      "Requirement already satisfied: nbformat>=4.2.0 in /opt/conda/lib/python3.7/site-packages (from ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (5.0.6)\n",
      "Requirement already satisfied: ipython>=4.0.0; python_version >= \"3.3\" in /opt/conda/lib/python3.7/site-packages (from ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (7.13.0)\n",
      "Requirement already satisfied: widgetsnbextension~=3.5.0 in /opt/conda/lib/python3.7/site-packages (from ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (3.5.1)\n",
      "Requirement already satisfied: jupyter-client in /opt/conda/lib/python3.7/site-packages (from ipykernel>=5.1.2; python_version >= \"3.4\"->pydeck>=0.1.dev5->streamlit->simpletransformers) (6.1.3)\n",
      "Requirement already satisfied: smmap<4,>=3.0.1 in /opt/conda/lib/python3.7/site-packages (from gitdb<5,>=4.0.1->GitPython>=1.0.0->wandb->simpletransformers) (3.0.2)\n",
      "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata; python_version < \"3.8\"->jsonschema->altair>=3.2.0->streamlit->simpletransformers) (3.1.0)\n",
      "Requirement already satisfied: jupyter-core in /opt/conda/lib/python3.7/site-packages (from nbformat>=4.2.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (4.6.3)\n",
      "Requirement already satisfied: backcall in /opt/conda/lib/python3.7/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.1.0)\n",
      "Requirement already satisfied: pexpect; sys_platform != \"win32\" in /opt/conda/lib/python3.7/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (4.8.0)\n",
      "Requirement already satisfied: pygments in /opt/conda/lib/python3.7/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (2.6.1)\n",
      "Requirement already satisfied: jedi>=0.10 in /opt/conda/lib/python3.7/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.15.2)\n",
      "Requirement already satisfied: pickleshare in /opt/conda/lib/python3.7/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.7.5)\n",
      "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (3.0.5)\n",
      "Requirement already satisfied: notebook>=4.4.1 in /opt/conda/lib/python3.7/site-packages (from widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (5.5.0)\n",
      "Requirement already satisfied: pyzmq>=13 in /opt/conda/lib/python3.7/site-packages (from jupyter-client->ipykernel>=5.1.2; python_version >= \"3.4\"->pydeck>=0.1.dev5->streamlit->simpletransformers) (19.0.0)\n",
      "Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.7/site-packages (from pexpect; sys_platform != \"win32\"->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.6.0)\n",
      "Requirement already satisfied: parso>=0.5.2 in /opt/conda/lib/python3.7/site-packages (from jedi>=0.10->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.5.2)\n",
      "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.7/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.1.9)\n",
      "Requirement already satisfied: nbconvert in /opt/conda/lib/python3.7/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (5.6.1)\n",
      "Requirement already satisfied: terminado>=0.8.1 in /opt/conda/lib/python3.7/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.8.3)\n",
      "Requirement already satisfied: Send2Trash in /opt/conda/lib/python3.7/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (1.5.0)\n",
      "Requirement already satisfied: testpath in /opt/conda/lib/python3.7/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.4.4)\n",
      "Requirement already satisfied: defusedxml in /opt/conda/lib/python3.7/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.6.0)\n",
      "Requirement already satisfied: pandocfilters>=1.4.1 in /opt/conda/lib/python3.7/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (1.4.2)\n",
      "Requirement already satisfied: mistune<2,>=0.8.1 in /opt/conda/lib/python3.7/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.8.4)\n",
      "Requirement already satisfied: bleach in /opt/conda/lib/python3.7/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (3.1.4)\n",
      "Requirement already satisfied: webencodings in /opt/conda/lib/python3.7/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.5.1)\n",
      "Building wheels for collected packages: seqeval\n",
      "  Building wheel for seqeval (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for seqeval: filename=seqeval-0.0.12-py3-none-any.whl size=7423 sha256=bcf7ad5bbc0caa44e784add6bc1cbbbfde92a010e60e5511a248037de8656a35\n",
      "  Stored in directory: /root/.cache/pip/wheels/dc/cc/62/a3b81f92d35a80e39eb9b2a9d8b31abac54c02b21b2d466edc\n",
      "Successfully built seqeval\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing collected packages: seqeval, tqdm, base58, ipykernel, pydeck, cachetools, enum-compat, astor, validators, streamlit, transformers, simpletransformers\n",
      "  Attempting uninstall: tqdm\n",
      "    Found existing installation: tqdm 4.45.0\n",
      "    Uninstalling tqdm-4.45.0:\n",
      "      Successfully uninstalled tqdm-4.45.0\n",
      "  Attempting uninstall: ipykernel\n",
      "    Found existing installation: ipykernel 5.1.1\n",
      "    Uninstalling ipykernel-5.1.1:\n",
      "      Successfully uninstalled ipykernel-5.1.1\n",
      "  Attempting uninstall: cachetools\n",
      "    Found existing installation: cachetools 3.1.1\n",
      "    Uninstalling cachetools-3.1.1:\n",
      "      Successfully uninstalled cachetools-3.1.1\n",
      "  Attempting uninstall: transformers\n",
      "    Found existing installation: transformers 2.11.0\n",
      "    Uninstalling transformers-2.11.0:\n",
      "      Successfully uninstalled transformers-2.11.0\n",
      "\u001b[31mERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.\n",
      "\n",
      "We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.\n",
      "\n",
      "transformers 3.1.0 requires tokenizers==0.8.1.rc2, but you'll have tokenizers 0.7.0 which is incompatible.\n",
      "jupyterlab-git 0.10.0 requires nbdime<2.0.0,>=1.1.0, but you'll have nbdime 2.0.0 which is incompatible.\n",
      "allennlp 1.0.0 requires transformers<2.12,>=2.9, but you'll have transformers 3.1.0 which is incompatible.\u001b[0m\n",
      "Successfully installed astor-0.8.1 base58-2.0.1 cachetools-4.1.1 enum-compat-0.0.3 ipykernel-5.3.4 pydeck-0.5.0b1 seqeval-0.0.12 simpletransformers-0.47.7 streamlit-0.66.0 tqdm-4.48.2 transformers-3.1.0 validators-0.18.1\n",
      "Found existing installation: tokenizers 0.7.0\n",
      "Uninstalling tokenizers-0.7.0:\n",
      "  Successfully uninstalled tokenizers-0.7.0\n",
      "Requirement already satisfied: transformers in /opt/conda/lib/python3.7/site-packages (3.1.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from transformers) (3.0.10)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages (from transformers) (2020.4.4)\n",
      "Collecting tokenizers==0.8.1.rc2\n",
      "  Downloading tokenizers-0.8.1rc2-cp37-cp37m-manylinux1_x86_64.whl (3.0 MB)\n",
      "\u001b[K     |████████████████████████████████| 3.0 MB 2.8 MB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from transformers) (2.23.0)\n",
      "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.7/site-packages (from transformers) (4.48.2)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from transformers) (1.18.5)\n",
      "Requirement already satisfied: sentencepiece!=0.1.92 in /opt/conda/lib/python3.7/site-packages (from transformers) (0.1.91)\n",
      "Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from transformers) (20.1)\n",
      "Requirement already satisfied: sacremoses in /opt/conda/lib/python3.7/site-packages (from transformers) (0.0.43)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->transformers) (2020.6.20)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->transformers) (2.9)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->transformers) (1.24.3)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/lib/python3.7/site-packages (from requests->transformers) (3.0.4)\n",
      "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from packaging->transformers) (1.14.0)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->transformers) (2.4.7)\n",
      "Requirement already satisfied: click in /opt/conda/lib/python3.7/site-packages (from sacremoses->transformers) (7.1.1)\n",
      "Requirement already satisfied: joblib in /opt/conda/lib/python3.7/site-packages (from sacremoses->transformers) (0.14.1)\n",
      "Installing collected packages: tokenizers\n",
      "\u001b[31mERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.\n",
      "\n",
      "We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.\n",
      "\n",
      "allennlp 1.0.0 requires transformers<2.12,>=2.9, but you'll have transformers 3.1.0 which is incompatible.\u001b[0m\n",
      "Successfully installed tokenizers-0.8.1rc2\n"
     ]
    }
   ],
   "source": [
    "!pip install simpletransformers\n",
    "!pip uninstall -y tokenizers\n",
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_seed(seed_value):\n",
    "    import random \n",
    "    random.seed(seed_value)  \n",
    "    import numpy as np\n",
    "    np.random.seed(seed_value)  \n",
    "    import torch\n",
    "    torch.manual_seed(seed_value)  \n",
    "    \n",
    "    if torch.cuda.is_available(): \n",
    "        torch.cuda.manual_seed(seed_value)\n",
    "        torch.cuda.manual_seed_all(seed_value)  \n",
    "        torch.backends.cudnn.deterministic = True   \n",
    "        torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import gc\n",
    "import re\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "from sklearn.metrics import accuracy_score, log_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the Text_ID column from both frames. Reassigning (instead of inplace=True)\n",
    "# with errors='ignore' keeps this cell safe to re-run once the column is already gone.\n",
    "train = train.drop(columns=['Text_ID'], errors='ignore')\n",
    "test = test.drop(columns=['Text_ID'], errors='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Product_Description</th>\n",
       "      <th>Product_Type</th>\n",
       "      <th>Sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1523</th>\n",
       "      <td>RT @mention ÷¼ Happy Woman's Day! Make love, ...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3868</th>\n",
       "      <td>RT @mention RT @mention Google to Launch Major...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4512</th>\n",
       "      <td>RT @mention Google to Launch Major New Social ...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4777</th>\n",
       "      <td>#SXSW is just starting, #CTIA is around the co...</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4991</th>\n",
       "      <td>Oh. My. God. The #SXSW app for iPad is pure, u...</td>\n",
       "      <td>7</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5436</th>\n",
       "      <td>RT @mention Marissa Mayer: Google Will Connect...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5480</th>\n",
       "      <td>Counting down the days to #sxsw plus strong Ca...</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5745</th>\n",
       "      <td>RT @mention ÷¼ GO BEYOND BORDERS! ÷_ {link} ...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5858</th>\n",
       "      <td>RT @mention Google to Launch Major New Social ...</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    Product_Description  Product_Type  \\\n",
       "1523  RT @mention ÷¼ Happy Woman's Day! Make love, ...             9   \n",
       "3868  RT @mention RT @mention Google to Launch Major...             9   \n",
       "4512  RT @mention Google to Launch Major New Social ...             9   \n",
       "4777  #SXSW is just starting, #CTIA is around the co...             0   \n",
       "4991  Oh. My. God. The #SXSW app for iPad is pure, u...             7   \n",
       "5436  RT @mention Marissa Mayer: Google Will Connect...             9   \n",
       "5480  Counting down the days to #sxsw plus strong Ca...             2   \n",
       "5745  RT @mention ÷¼ GO BEYOND BORDERS! ÷_ {link} ...             9   \n",
       "5858  RT @mention Google to Launch Major New Social ...             9   \n",
       "\n",
       "      Sentiment  \n",
       "1523          2  \n",
       "3868          2  \n",
       "4512          2  \n",
       "4777          3  \n",
       "4991          3  \n",
       "5436          2  \n",
       "5480          3  \n",
       "5745          2  \n",
       "5858          2  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train[train.duplicated()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(train.shape)\n",
    "# train = train.drop_duplicates()\n",
    "# print(train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Product_Description</th>\n",
       "      <th>Product_Type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1436</th>\n",
       "      <td>RT @mention Google to Launch Major New Social ...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2457</th>\n",
       "      <td>RT @mention Marissa Mayer: Google Will Connect...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    Product_Description  Product_Type\n",
       "1436  RT @mention Google to Launch Major New Social ...             9\n",
       "2457  RT @mention Marissa Mayer: Google Will Connect...             3"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test[test.duplicated()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Product_Description</th>\n",
       "      <th>Product_Type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>103</th>\n",
       "      <td>Before It Even Begins, Apple Wins #SXSW {link}</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>266</th>\n",
       "      <td>Need to buy an iPad2 while I'm in Austin at #s...</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>303</th>\n",
       "      <td>RT @mention Marissa Mayer: Google Will Connect...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>317</th>\n",
       "      <td>RT @mention Google to Launch Major New Social ...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>586</th>\n",
       "      <td>Marissa Mayer: Google Will Connect the Digital...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>876</th>\n",
       "      <td>Google to Launch Major New Social Network Call...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>906</th>\n",
       "      <td>RT @mention Google to Launch Major New Social ...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1025</th>\n",
       "      <td>I just noticed DST is coming this weekend. How...</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1094</th>\n",
       "      <td>RT @mention Marissa Mayer: Google Will Connect...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1395</th>\n",
       "      <td>Really enjoying the changes in Gowalla 3.0 for...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1436</th>\n",
       "      <td>RT @mention Google to Launch Major New Social ...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2419</th>\n",
       "      <td>RT @mention Marissa Mayer: Google Will Connect...</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2457</th>\n",
       "      <td>RT @mention Marissa Mayer: Google Will Connect...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2518</th>\n",
       "      <td>Win free iPad 2 from webdoc.com #sxsw RT</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2692</th>\n",
       "      <td>RT @mention Marissa Mayer: Google Will Connect...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    Product_Description  Product_Type\n",
       "103      Before It Even Begins, Apple Wins #SXSW {link}             2\n",
       "266   Need to buy an iPad2 while I'm in Austin at #s...             6\n",
       "303   RT @mention Marissa Mayer: Google Will Connect...             9\n",
       "317   RT @mention Google to Launch Major New Social ...             9\n",
       "586   Marissa Mayer: Google Will Connect the Digital...             9\n",
       "876   Google to Launch Major New Social Network Call...             9\n",
       "906   RT @mention Google to Launch Major New Social ...             9\n",
       "1025  I just noticed DST is coming this weekend. How...             8\n",
       "1094  RT @mention Marissa Mayer: Google Will Connect...             3\n",
       "1395  Really enjoying the changes in Gowalla 3.0 for...             1\n",
       "1436  RT @mention Google to Launch Major New Social ...             9\n",
       "2419  RT @mention Marissa Mayer: Google Will Connect...             9\n",
       "2457  RT @mention Marissa Mayer: Google Will Connect...             3\n",
       "2518           Win free iPad 2 from webdoc.com #sxsw RT             9\n",
       "2692  RT @mention Marissa Mayer: Google Will Connect...             3"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test[test['Product_Description'].isin(train['Product_Description'].values.tolist())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename the description column to 'text' in both frames and the target to 'labels'.\n",
    "train = train.rename({'Product_Description': 'text', 'Sentiment': 'labels'}, axis='columns')\n",
    "test = test.rename({'Product_Description': 'text'}, axis='columns')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the Product_Type code to the end of each text, separated by a single space.\n",
    "train['text'] = train['text'].astype(str).str.cat(train['Product_Type'].astype(str), sep=' ')\n",
    "test['text'] = test['text'].astype(str).str.cat(test['Product_Type'].astype(str), sep=' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the Product_Type column from both frames. Reassigning (instead of inplace=True)\n",
    "# with errors='ignore' keeps this cell safe to re-run once the column is already gone.\n",
    "train = train.drop(columns=['Product_Type'], errors='ignore')\n",
    "test = test.drop(columns=['Product_Type'], errors='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "count    6364.000000\n",
      "mean       18.730830\n",
      "std         4.965762\n",
      "min         3.000000\n",
      "25%        15.000000\n",
      "50%        19.000000\n",
      "75%        22.000000\n",
      "max        34.000000\n",
      "Name: text, dtype: float64\n"
     ]
    }
   ],
   "source": [
    "print(train['text'].apply(lambda x: len(x.split())).describe())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "count    2728.000000\n",
      "mean       18.846041\n",
      "std         4.949445\n",
      "min         5.000000\n",
      "25%        15.000000\n",
      "50%        19.000000\n",
      "75%        23.000000\n",
      "max        32.000000\n",
      "Name: text, dtype: float64\n"
     ]
    }
   ],
   "source": [
    "print(test['text'].apply(lambda x: len(x.split())).describe())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n"
     ]
    }
   ],
   "source": [
    "from simpletransformers.classification import ClassificationModel\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "from sklearn.metrics import accuracy_score, log_loss\n",
    "from scipy.special import softmax\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_args = {'train_batch_size': 128, \n",
    "              'reprocess_input_data': True,\n",
    "              'overwrite_output_dir': True,\n",
    "              'fp16': False,\n",
    "              'do_lower_case': False,\n",
    "              'num_train_epochs': 7,\n",
    "              'max_seq_length': 34,\n",
    "              'regression': False,\n",
    "              'manual_seed': 1994,\n",
    "              \"learning_rate\": 1e-5,\n",
    "              #'weight_decay': 0.01,\n",
    "              \"save_eval_checkpoints\": False,\n",
    "              \"save_model_every_epoch\": False,\n",
    "              'no_cache':True,\n",
    "              \"silent\": False,\n",
    "              \"use_early_stopping\": True,\n",
    "              #\"early_stopping_delta\": 0.015,\n",
    "              \"early_stopping_metric\": \"mcc\",\n",
    "              \"early_stopping_metric_minimize\": False,\n",
    "              \"early_stopping_patience\": 3,\n",
    "              \"evaluate_during_training_steps\": 1000\n",
    "              }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model(model_type='roberta', model_name='roberta-base', num_labels=4, use_cuda=None):\n",
    "    \"\"\"Build a fresh simpletransformers ClassificationModel configured with model_args.\n",
    "\n",
    "    With no arguments this creates a 4-label roberta-base classifier.\n",
    "\n",
    "    Args:\n",
    "        model_type: Architecture key passed to ClassificationModel.\n",
    "        model_name: Checkpoint name or path passed to ClassificationModel.\n",
    "        num_labels: Number of target classes.\n",
    "        use_cuda: True forces GPU, False forces CPU. None uses the GPU only when\n",
    "            torch reports one, so CPU-only sessions no longer fail at construction.\n",
    "\n",
    "    Returns:\n",
    "        A newly constructed ClassificationModel.\n",
    "    \"\"\"\n",
    "    if use_cuda is None:\n",
    "        import torch\n",
    "        use_cuda = torch.cuda.is_available()\n",
    "    return ClassificationModel(model_type, model_name, use_cuda=use_cuda,\n",
    "                               num_labels=num_labels, args=model_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3595d5246c0241a89615c96ccc46cccf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0393f4d75bab43f297a11c2aa27da770",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501200538.0, style=ProgressStyle(descri…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "045adc9531e04e8c8b868c1068c8bfb3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5045bf1d966e400fa18377208b1795e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50522f48ff124649bf47f4532264bcdb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5091.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0368b678cb7d40b882167a05ab8a6511",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=7.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8af6081fd6084bd48ccc3a1259c6c096",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb2c0e95a61d40f7a765192c5e0102f3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "188a5d7b44684ed7bd32c9d5ae0d6f19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e5847f66c2a459d909f439c6c9b677b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8e62b6bf2834c219223f47c31ffaffa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f2e6269d62ce4bf88182ba97330aac2a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 5 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "74389e9c91894dbea10d93a68d525602",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 6 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5cbb9d149e45427d8876a69391473695",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1273.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8280d7f03dd64499a5b40f35d66fb273",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=160.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Log_Loss: 0.5512822086777237\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71c78e1e61df42708ec32fb059618edd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2728.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b54c1025b41945deb2aef3f9149ec702",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=341.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9953c7f262ad42a1b1e772b8c2727bba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5091.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e408452769e943daba8c9619a769faf3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=7.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e062be77d4b499abf5c7a8615c918a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d9a86f60d1d4e899fff3c9076701bd8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a369337e7a9244989d6e8dbcea033412",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d381a3204337471f97f3615526128f37",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "40b0f0d049ae49f3b1f26670cd680be5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb5363dcab9f44059620a6745fcbadb7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 5 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12bdd09cffad411e8cf7d949e31805b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 6 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8310e66df44547f0a32eb70898c7763a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1273.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f73bab836c404d3eb367ef0119da4de7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=160.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Log_Loss: 0.574208609082349\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a531a108cf848be90b2d9745c5b6a25",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2728.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "88be632e7c714013b15e8c0e94b21de6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=341.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f5736fca6a14cc2bc18a533be942f06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5091.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd05d6d18be84849b3c7b6035a215a30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=7.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "20e4b60a010c4ae7a53748e34babee50",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0a3ec1b98ccc44798f33e947a7310f63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "40205666242443a7ae5c1535927505e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "420fca4be83e472a912a1405d96fbeed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "610993d092364bf6af81e136765ab140",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06dd88cdf5e949f3a0f6505718492e55",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 5 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6b350160e4c41a9aa72ba24f1d3fd5b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 6 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4590abc77464fff8ae7e7a10ac186bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1273.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8976889437e46cda637f3b30bbe92e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=160.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Log_Loss: 0.6119374637956565\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "964681305bab4da88128f6006857da0d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2728.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3437b9a9a7594bafb5d7283b6a0c30aa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=341.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6205cde133224ee1a84acf39c0612e71",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5091.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a73502b89db14a6bae14b5098bde7c0d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=7.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22574469da1645208a3af784e78f3e8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aeafa6513fe54c26bde1077dd431a30d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "067187b7f9c241e381c13b95ce58d28c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "61deeaec31ec4620bb837d814416b1c4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5f5d57c526924fd294c17b35bae02a19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf66b66bd8f54799990ab208c523e981",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 5 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76d380907d8f4927b76275ccef445809",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 6 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "227473719f5249b4bb510222cbb0ea7b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1273.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc42f828f62944c8b0003b3558f43bcb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=160.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Log_Loss: 0.6194366897424889\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d87ccfa05e74e2fb6f5524a27586f16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2728.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8b23f33d02340d7b7ccdfd5c6b8a07d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=341.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4288f1aac97411ebab0e2ffc4a7c838",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5092.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97d05d540e9d4a6395870ba8581db9d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=7.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e52c1c243834c14adf3a98d2a0897c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a63b315251e140a68e94361df8a99c19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "334625bc50c74856aa3074308fe7de4c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ba1d0fbfa0142f5860fa3b3d247c4b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90656caddccd4dfc8b8c9d455f174628",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "31c48ef63ad34c559c2d3ab4b72ae872",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 5 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ac885394b52405a85be9471b89a523a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 6 of 7', max=40.0, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1c7a2933829f4717867a7a0127c6c167",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1272.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd620f1a92da4144b83332fcc9c61a7d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=159.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Log_Loss: 0.5776664138472868\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e23c6ede41414d2aa9fdc6186075d6fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2728.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb88e319c04d415487be4130d3cbfbf9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=341.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 5-fold stratified cross-validation: on every fold a model obtained from\n",
    "# get_model() is trained on the training split, scored on the held-out split,\n",
    "# and then used to predict on `test`.\n",
    "err = []         # validation log-loss of each fold\n",
    "y_pred_tot = []  # softmax-normalised test outputs of each fold\n",
    "i = 1  # NOTE(review): not used anywhere in this cell; kept in case a later cell reads it\n",
    "\n",
    "fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=1994)\n",
    "\n",
    "for train_index, test_index in fold.split(train, train['labels']):\n",
    "    train1_trn, train1_val = train.iloc[train_index], train.iloc[test_index]\n",
    "    model = get_model()\n",
    "    # Collect garbage only after `model` has been rebound, so the previous\n",
    "    # fold's model object is no longer referenced by that name.\n",
    "    gc.collect()\n",
    "\n",
    "    model.train_model(train1_trn)\n",
    "\n",
    "    # Raw validation outputs are normalised with softmax along axis 1\n",
    "    # (presumably the class axis) before being passed to log_loss.\n",
    "    score, raw_outputs_val, wrong_preds = model.eval_model(train1_val)\n",
    "    raw_outputs_val = softmax(raw_outputs_val, axis=1)\n",
    "\n",
    "    # Compute the fold metric once and reuse it for the printout and for `err`.\n",
    "    fold_log_loss = log_loss(train1_val['labels'], raw_outputs_val)\n",
    "    print('Log_Loss:', fold_log_loss)\n",
    "    err.append(fold_log_loss)\n",
    "\n",
    "    # Same normalisation for the test outputs; they are stored per fold so a\n",
    "    # later cell can combine them.\n",
    "    predictions, raw_outputs_test = model.predict(test['text'])\n",
    "    raw_outputs_test = softmax(raw_outputs_test, axis=1)\n",
    "    y_pred_tot.append(raw_outputs_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.586906277029101"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average of the values collected in `err`, shown as the cell output.\n",
    "np.mean(err, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise average, along the first axis, of the arrays stored in `y_pred_tot`.\n",
    "y_pred = np.mean(y_pred_tot, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     0    1    2    3\n",
       "0  0.0  0.0  0.0  0.0\n",
       "1  0.0  0.0  0.0  0.0\n",
       "2  0.0  0.0  0.0  0.0\n",
       "3  0.0  0.0  0.0  0.0\n",
       "4  0.0  0.0  0.0  0.0"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of `samp`.\n",
    "samp.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.006008</td>\n",
       "      <td>0.006244</td>\n",
       "      <td>0.711419</td>\n",
       "      <td>0.276329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.016405</td>\n",
       "      <td>0.023132</td>\n",
       "      <td>0.928492</td>\n",
       "      <td>0.031971</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.014886</td>\n",
       "      <td>0.065645</td>\n",
       "      <td>0.354355</td>\n",
       "      <td>0.565115</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.005009</td>\n",
       "      <td>0.011859</td>\n",
       "      <td>0.067244</td>\n",
       "      <td>0.915888</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.005622</td>\n",
       "      <td>0.003685</td>\n",
       "      <td>0.931481</td>\n",
       "      <td>0.059212</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          0         1         2         3\n",
       "0  0.006008  0.006244  0.711419  0.276329\n",
       "1  0.016405  0.023132  0.928492  0.031971\n",
       "2  0.014886  0.065645  0.354355  0.565115\n",
       "3  0.005009  0.011859  0.067244  0.915888\n",
       "4  0.005622  0.003685  0.931481  0.059212"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Overwrite columns '0'-'3' of samp with the contents of y_pred, then preview the first rows.\n",
    "# NOTE(review): assumes y_pred has one row per row of samp and four columns ordered as classes 0-3 -- confirm against the cell that builds y_pred.\n",
    "samp[['0','1','2','3']] = y_pred\n",
    "samp.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-read the train and test CSVs from the Kaggle input directory (train was already read from this path at the top of the notebook).\n",
    "# NOTE(review): presumably this restores the raw columns after earlier cells modified train/test -- confirm.\n",
    "train = pd.read_csv(\"/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Train.csv\")\n",
    "test = pd.read_csv(\"/kaggle/input/machine-hack-product-sentiment-classification/Participants_Data/Test.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['Text_ID', 'Product_Description', 'Product_Type', 'Sentiment'], dtype='object')"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the column names of the train frame that was just re-read.\n",
    "train.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the test rows whose (Product_Description, Product_Type) pair also appears in train,\n",
    "# together with the train Sentiment that will be forced onto them.\n",
    "#   test_ids    -> index labels of the matching test rows, in test order\n",
    "#   force_senti -> the Sentiment chosen for each of those rows\n",
    "key_cols = ['Product_Description', 'Product_Type']\n",
    "\n",
    "# One label per distinct key: the most frequent train Sentiment. Series.mode() returns its\n",
    "# values sorted, so ties resolve to the smallest label instead of depending on set ordering.\n",
    "known_senti = (\n",
    "    train.groupby(key_cols)['Sentiment']\n",
    "    .agg(lambda labels: labels.mode().iloc[0])\n",
    "    .rename('forced_senti')\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# A left merge keeps test's row order and row count because every key in known_senti is unique.\n",
    "# This replaces the per-row scan of train (one full filter of train for every test row).\n",
    "matched = test[key_cols].merge(known_senti, on=key_cols, how='left')\n",
    "has_match = matched['forced_senti'].notna().to_numpy()\n",
    "\n",
    "test_ids = test.index[has_match].tolist()\n",
    "# The merge upcasts the label column to float when some rows are unmatched; restore train's dtype.\n",
    "force_senti = matched.loc[has_match, 'forced_senti'].astype(train['Sentiment'].dtype).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot row for every sentiment class 0-3, keyed by the class label.\n",
    "dic = {label: [int(col == label) for col in range(4)] for label in range(4)}\n",
    "\n",
    "# Replace the row of samp at each matched position with the one-hot row of its forced label.\n",
    "for row_pos, senti in zip(test_ids, force_senti):\n",
    "    samp.iloc[row_pos] = dic[senti]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.006008</td>\n",
       "      <td>0.006244</td>\n",
       "      <td>0.711419</td>\n",
       "      <td>0.276329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.016405</td>\n",
       "      <td>0.023132</td>\n",
       "      <td>0.928492</td>\n",
       "      <td>0.031971</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.014886</td>\n",
       "      <td>0.065645</td>\n",
       "      <td>0.354355</td>\n",
       "      <td>0.565115</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          0         1         2         3\n",
       "0  0.006008  0.006244  0.711419  0.276329\n",
       "1  0.016405  0.023132  0.928492  0.031971\n",
       "2  0.014886  0.065645  0.354355  0.565115"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check the first three rows of samp after the forced-label overwrite.\n",
    "samp.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write samp to a CSV in the working directory; index=False keeps the row index out of the file.\n",
    "samp.to_csv('Sub_v0.3.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
